import asyncio
import fnmatch
import glob
import json
import mimetypes
import os
import re
import shutil
import urllib.parse
import webbrowser
from contextlib import AsyncExitStack, asynccontextmanager
from pathlib import Path
from typing import List, Optional, Union, cast

import socketio
from fastapi import (
    APIRouter,
    Depends,
    FastAPI,
    Form,
    HTTPException,
    Query,
    Request,
    Response,
    UploadFile,
    status,
)
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from starlette.datastructures import URL
from starlette.middleware.cors import CORSMiddleware
from starlette.types import Receive, Scope, Send
from typing_extensions import Annotated
from watchfiles import awatch

from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
from chainlit.auth.cookie import (
    clear_auth_cookie,
    clear_oauth_state_cookie,
    set_auth_cookie,
    set_oauth_state_cookie,
    validate_oauth_state_cookie,
)
from chainlit.config import (
    APP_ROOT,
    BACKEND_ROOT,
    DEFAULT_HOST,
    FILES_DIRECTORY,
    PACKAGE_ROOT,
    ChainlitConfig,
    config,
    load_module,
    public_dir,
    reload_config,
)
from chainlit.data import get_data_layer
from chainlit.data.acl import is_thread_author
from chainlit.logger import logger
from chainlit.markdown import get_markdown_str
from chainlit.oauth_providers import get_oauth_provider
from chainlit.secret import random_secret
from chainlit.types import (
    AskFileSpec,
    CallActionRequest,
    ConnectMCPRequest,
    DeleteFeedbackRequest,
    DeleteThreadRequest,
    DisconnectMCPRequest,
    ElementRequest,
    GetThreadsRequest,
    ShareThreadRequest,
    Theme,
    UpdateFeedbackRequest,
    UpdateThreadRequest,
)
from chainlit.user import PersistedUser, User
from chainlit.utils import utc_now

from ._utils import is_path_inside

mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Context manager to handle app start and shutdown."""
    if config.code.on_app_startup:
        await config.code.on_app_startup()

    host = config.run.host
    port = config.run.port
    root_path = os.getenv("CHAINLIT_ROOT_PATH", "")

    if host == DEFAULT_HOST:
        url = f"http://localhost:{port}{root_path}"
    else:
        url = f"http://{host}:{port}{root_path}"

    logger.info(f"Your app is available at {url}")

    if not config.run.headless:
        # Add a delay before opening the browser
        await asyncio.sleep(1)
        webbrowser.open(url)

    watch_task = None
    stop_event = asyncio.Event()

    if config.run.watch:

        async def watch_files_for_changes():
            extensions = [".py"]
            files = ["chainlit.md", "config.toml"]
            async for changes in awatch(config.root, stop_event=stop_event):
                for change_type, file_path in changes:
                    file_name = os.path.basename(file_path)
                    file_ext = os.path.splitext(file_name)[1]

                    if file_ext.lower() in extensions or file_name.lower() in files:
                        logger.info(
                            f"File {change_type.name}: {file_name}. Reloading app..."
                        )

                        try:
                            reload_config()
                        except Exception as e:
                            logger.error(f"Error reloading config: {e}")
                            break

                        # Reload the module if the module name is specified in the config
                        if config.run.module_name:
                            try:
                                load_module(config.run.module_name, force_refresh=True)
                            except Exception as e:
                                logger.error(f"Error reloading module: {e}")

                        await asyncio.sleep(1)
                        await sio.emit("reload", {})

                        break

        watch_task = asyncio.create_task(watch_files_for_changes())

    discord_task = None

    if discord_bot_token := os.environ.get("DISCORD_BOT_TOKEN"):
        from chainlit.discord.app import client

        discord_task = asyncio.create_task(client.start(discord_bot_token))

    slack_task = None

    # Slack Socket Handler if env variable SLACK_WEBSOCKET_TOKEN is set
    if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_WEBSOCKET_TOKEN"):
        from chainlit.slack.app import start_socket_mode

        slack_task = asyncio.create_task(start_socket_mode())

    try:
        yield
    finally:
        try:
            if config.code.on_app_shutdown:
                await config.code.on_app_shutdown()

            if watch_task:
                stop_event.set()
                watch_task.cancel()
                await watch_task

            if discord_task:
                discord_task.cancel()
                await discord_task

            if slack_task:
                slack_task.cancel()
                await slack_task

            if data_layer := get_data_layer():
                await data_layer.close()
        except asyncio.exceptions.CancelledError:
            pass

        if FILES_DIRECTORY.is_dir():
            shutil.rmtree(FILES_DIRECTORY)

        # Force exit the process to avoid potential AnyIO threads still running
        os._exit(0)


def get_build_dir(local_target: str, packaged_target: str) -> str:
    """
    Get the build directory based on the UI build strategy.

    Args:
        local_target (str): The local target directory.
        packaged_target (str): The packaged target directory.

    Returns:
        str: The build directory
    """

    local_build_dir = os.path.join(PACKAGE_ROOT, local_target, "dist")
    packaged_build_dir = os.path.join(BACKEND_ROOT, packaged_target, "dist")

    if config.ui.custom_build and os.path.exists(
        os.path.join(APP_ROOT, config.ui.custom_build)
    ):
        return os.path.join(APP_ROOT, config.ui.custom_build)
    elif os.path.exists(local_build_dir):
        return local_build_dir
    elif os.path.exists(packaged_build_dir):
        return packaged_build_dir
    else:
        raise FileNotFoundError(f"{local_target} built UI dir not found")


build_dir = get_build_dir("frontend", "frontend")
copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")

app = FastAPI(lifespan=lifespan)

sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")

asgi_app = socketio.ASGIApp(socketio_server=sio, socketio_path="")

# config.run.root_path is only set when started with --root-path. Not on submounts.
SOCKET_IO_PATH = f"{config.run.root_path}/ws/socket.io"
app.mount(SOCKET_IO_PATH, asgi_app)

app.add_middleware(
    CORSMiddleware,
    allow_origins=config.project.allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SafariWebSocketsCompatibleGZipMiddleware(GZipMiddleware):
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # Prevent gzip compression for HTTP requests to socket.io path due to a bug in Safari
        if URL(scope=scope).path.startswith(SOCKET_IO_PATH):
            await self.app(scope, receive, send)
        else:
            await super().__call__(scope, receive, send)


app.add_middleware(SafariWebSocketsCompatibleGZipMiddleware)

# config.run.root_path is only set when started with --root-path. Not on submounts.
router = APIRouter(prefix=config.run.root_path)


@router.get("/public/{filename:path}")
async def serve_public_file(
    filename: str,
):
    """Serve a file from public dir."""

    base_path = Path(public_dir)
    file_path = (base_path / filename).resolve()

    if not is_path_inside(file_path, base_path):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if file_path.is_file():
        return FileResponse(file_path)
    else:
        raise HTTPException(status_code=404, detail="File not found")


@router.get("/assets/{filename:path}")
async def serve_asset_file(
    filename: str,
):
    """Serve a file from assets dir."""

    base_path = Path(os.path.join(build_dir, "assets"))
    file_path = (base_path / filename).resolve()

    if not is_path_inside(file_path, base_path):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if file_path.is_file():
        return FileResponse(file_path)
    else:
        raise HTTPException(status_code=404, detail="File not found")


@router.get("/copilot/{filename:path}")
async def serve_copilot_file(
    filename: str,
):
    """Serve a file from assets dir."""

    base_path = Path(copilot_build_dir)
    file_path = (base_path / filename).resolve()

    if not is_path_inside(file_path, base_path):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if file_path.is_file():
        return FileResponse(file_path)
    else:
        raise HTTPException(status_code=404, detail="File not found")


# -------------------------------------------------------------------------------
#                               SLACK HTTP HANDLER
# -------------------------------------------------------------------------------

if (
    os.environ.get("SLACK_BOT_TOKEN")
    and os.environ.get("SLACK_SIGNING_SECRET")
    and not os.environ.get("SLACK_WEBSOCKET_TOKEN")
):
    from chainlit.slack.app import slack_app_handler

    @router.post("/slack/events")
    async def slack_endpoint(req: Request):
        return await slack_app_handler.handle(req)


# -------------------------------------------------------------------------------
#                               TEAMS HANDLER
# -------------------------------------------------------------------------------

if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
    from botbuilder.schema import Activity

    from chainlit.teams.app import adapter, bot

    @router.post("/teams/events")
    async def teams_endpoint(req: Request):
        body = await req.json()
        activity = Activity().deserialize(body)
        auth_header = req.headers.get("Authorization", "")
        response = await adapter.process_activity(activity, auth_header, bot.on_turn)
        return response


# -------------------------------------------------------------------------------
#                               HTTP HANDLERS
# -------------------------------------------------------------------------------


def replace_between_tags(
    text: str, start_tag: str, end_tag: str, replacement: str
) -> str:
    """Replace text between two tags in a string."""

    pattern = start_tag + ".*?" + end_tag
    return re.sub(pattern, start_tag + replacement + end_tag, text, flags=re.DOTALL)


def get_html_template(root_path):
    """
    Get HTML template for the index view.
    """
    root_path = root_path.rstrip("/")  # Avoid duplicated / when joining with root path.

    custom_theme = None
    custom_theme_file_path = Path(public_dir) / "theme.json"
    if (
        is_path_inside(custom_theme_file_path, Path(public_dir))
        and custom_theme_file_path.is_file()
    ):
        custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))

    PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
    JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
    CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"

    default_url = config.ui.custom_meta_url or "https://github.com/Chainlit/chainlit"
    default_meta_image_url = (
        "https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png"
    )
    meta_image_url = config.ui.custom_meta_image_url or default_meta_image_url
    favicon_path = "/favicon"

    tags = f"""<title>{config.ui.name}</title>
    <link rel="icon" href="{favicon_path}" />
    <meta name="description" content="{config.ui.description}">
    <meta property="og:type" content="website">
    <meta property="og:title" content="{config.ui.name}">
    <meta property="og:description" content="{config.ui.description}">
    <meta property="og:image" content="{meta_image_url}">
    <meta property="og:url" content="{default_url}">
    <meta property="og:root_path" content="{root_path}">"""

    js = f"""<script>
{f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
{f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
</script>"""

    css = None
    if config.ui.custom_css:
        css = f"""<link rel="stylesheet" type="text/css" href="{config.ui.custom_css}" {config.ui.custom_css_attributes}>"""

    if config.ui.custom_js:
        js += f"""<script src="{config.ui.custom_js}" {config.ui.custom_js_attributes}></script>"""

    font = None
    if custom_theme and custom_theme.get("custom_fonts"):
        font = "\n".join(
            f"""<link rel="stylesheet" href="{font}">"""
            for font in custom_theme.get("custom_fonts")
        )

    index_html_file_path = os.path.join(build_dir, "index.html")

    with open(index_html_file_path, encoding="utf-8") as f:
        content = f.read()
        content = content.replace(PLACEHOLDER, tags)
        if js:
            content = content.replace(JS_PLACEHOLDER, js)
        if css:
            content = content.replace(CSS_PLACEHOLDER, css)
        if font:
            content = replace_between_tags(
                content, "<!-- FONT START -->", "<!-- FONT END -->", font
            )
        content = content.replace('href="/', f'href="{root_path}/')
        content = content.replace('src="/', f'src="{root_path}/')
        return content


def get_user_facing_url(url: URL):
    """
    Return the user facing URL for a given URL.
    Handles deployment with proxies (like cloud run).
    """
    chainlit_url = os.environ.get("CHAINLIT_URL")

    # No config, we keep the URL as is
    if not chainlit_url:
        url = url.replace(query="", fragment="")
        return url.__str__()

    config_url = URL(chainlit_url).replace(
        query="",
        fragment="",
    )
    # Remove trailing slash from config URL
    if config_url.path.endswith("/"):
        config_url = config_url.replace(path=config_url.path[:-1])

    return config_url.__str__() + url.path


@router.get("/auth/config")
async def auth(request: Request):
    return get_configuration()


def _get_response_dict(access_token: str) -> dict:
    """Get the response dictionary for the auth response."""

    return {"success": True}


def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
    """Get the redirect params for the OAuth callback."""

    response_dict = _get_response_dict(access_token)

    if redirect_to_callback:
        root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
        root_path = "" if root_path == "/" else root_path
        redirect_url = (
            f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
        )

        return RedirectResponse(
            # FIXME: redirect to the right frontend base url to improve the dev environment
            url=redirect_url,
            status_code=302,
        )

    return JSONResponse(response_dict)


def _get_oauth_redirect_error(request: Request, error: str) -> Response:
    """Get the redirect response for an OAuth error."""
    params = urllib.parse.urlencode(
        {
            "error": error,
        }
    )
    response = RedirectResponse(url=str(request.url_for("login")) + "?" + params)
    return response


async def _authenticate_user(
    request: Request, user: Optional[User], redirect_to_callback: bool = False
) -> Response:
    """Authenticate a user and return the response."""

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="credentialssignin",
        )

    # If a data layer is defined, attempt to persist user.
    if data_layer := get_data_layer():
        try:
            await data_layer.create_user(user)
        except Exception as e:
            # Catch and log exceptions during user creation.
            # TODO: Make this catch only specific errors and allow others to propagate.
            logger.error(f"Error creating user: {e}")

    access_token = create_jwt(user)

    response = _get_auth_response(access_token, redirect_to_callback)

    set_auth_cookie(request, response, access_token)

    return response


@router.post("/login")
async def login(
    request: Request,
    response: Response,
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """
    Login a user using the password auth callback.
    """
    if not config.code.password_auth_callback:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
        )

    user = await config.code.password_auth_callback(
        form_data.username, form_data.password
    )

    return await _authenticate_user(request, user)


@router.post("/logout")
async def logout(request: Request, response: Response):
    """Logout the user by calling the on_logout callback."""
    clear_auth_cookie(request, response)

    if config.code.on_logout:
        return await config.code.on_logout(request, response)

    return {"success": True}


@router.post("/auth/jwt")
async def jwt_auth(request: Request):
    """Login a user using a valid jwt."""
    from jwt import InvalidTokenError

    auth_header: Optional[str] = request.headers.get("Authorization")
    if not auth_header:
        raise HTTPException(status_code=401, detail="Authorization header missing")

    # Check if it starts with "Bearer "
    try:
        scheme, token = auth_header.split()
        if scheme.lower() != "bearer":
            raise HTTPException(
                status_code=401,
                detail="Invalid authentication scheme. Please use Bearer",
            )
    except ValueError:
        raise HTTPException(
            status_code=401, detail="Invalid authorization header format"
        )

    try:
        user = decode_jwt(token)
        return await _authenticate_user(request, user)
    except InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")


@router.post("/auth/header")
async def header_auth(request: Request):
    """Login a user using the header_auth_callback."""
    if not config.code.header_auth_callback:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No header_auth_callback defined",
        )

    user = await config.code.header_auth_callback(request.headers)

    return await _authenticate_user(request, user)


@router.get("/auth/oauth/{provider_id}")
async def oauth_login(provider_id: str, request: Request):
    """Redirect the user to the oauth provider login page."""
    if config.code.oauth_callback is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No oauth_callback defined",
        )

    provider = get_oauth_provider(provider_id)
    if not provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider {provider_id} not found",
        )

    random = random_secret(32)

    params = urllib.parse.urlencode(
        {
            "client_id": provider.client_id,
            "redirect_uri": f"{get_user_facing_url(request.url)}/callback",
            "state": random,
            **provider.authorize_params,
        }
    )
    response = RedirectResponse(
        url=f"{provider.authorize_url}?{params}",
    )

    set_oauth_state_cookie(response, random)

    return response


@router.get("/auth/oauth/{provider_id}/callback")
async def oauth_callback(
    provider_id: str,
    request: Request,
    error: Optional[str] = None,
    code: Optional[str] = None,
    state: Optional[str] = None,
):
    """Handle the oauth callback and login the user."""

    if config.code.oauth_callback is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No oauth_callback defined",
        )

    provider = get_oauth_provider(provider_id)
    if not provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider {provider_id} not found",
        )

    if error:
        return _get_oauth_redirect_error(request, error)

    if not code or not state:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing code or state",
        )

    try:
        validate_oauth_state_cookie(request, state)
    except Exception as e:
        logger.exception("Unable to validate oauth state: %1", e)

        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unauthorized",
        )

    url = get_user_facing_url(request.url)
    token = await provider.get_token(code, url)

    (raw_user_data, default_user) = await provider.get_user_info(token)

    user = await config.code.oauth_callback(
        provider_id, token, raw_user_data, default_user
    )

    response = await _authenticate_user(request, user, redirect_to_callback=True)

    clear_oauth_state_cookie(response)

    return response


# specific route for azure ad hybrid flow
@router.post("/auth/oauth/azure-ad-hybrid/callback")
async def oauth_azure_hf_callback(
    request: Request,
    error: Optional[str] = None,
    code: Annotated[Optional[str], Form()] = None,
    id_token: Annotated[Optional[str], Form()] = None,
):
    """Handle the azure ad hybrid flow callback and login the user."""

    provider_id = "azure-ad-hybrid"
    if config.code.oauth_callback is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No oauth_callback defined",
        )

    provider = get_oauth_provider(provider_id)
    if not provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider {provider_id} not found",
        )

    if error:
        return _get_oauth_redirect_error(request, error)

    if not code:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing code",
        )

    url = get_user_facing_url(request.url)
    token = await provider.get_token(code, url)

    (raw_user_data, default_user) = await provider.get_user_info(token)

    user = await config.code.oauth_callback(
        provider_id, token, raw_user_data, default_user, id_token
    )

    response = await _authenticate_user(request, user, redirect_to_callback=True)

    clear_oauth_state_cookie(response)

    return response


GenericUser = Union[User, PersistedUser, None]
UserParam = Annotated[GenericUser, Depends(get_current_user)]


@router.get("/user")
async def get_user(current_user: UserParam) -> GenericUser:
    return current_user


_language_pattern = (
    "^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,4})?(-[a-zA-Z0-9]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
)


@router.post("/set-session-cookie")
async def set_session_cookie(request: Request, response: Response):
    body = await request.json()
    session_id = body.get("session_id")

    is_local = request.client and request.client.host in ["127.0.0.1", "localhost"]

    response.set_cookie(
        key="X-Chainlit-Session-id",
        value=session_id,
        path="/",
        httponly=True,
        secure=not is_local,
        samesite="lax" if is_local else "none",
    )

    return {"message": "Session cookie set"}


@router.get("/project/translations")
async def project_translations(
    language: str = Query(
        default="en-US", description="Language code", pattern=_language_pattern
    ),
):
    """Return project translations."""

    # Load translation based on the provided language
    translation = config.load_translation(language)

    return JSONResponse(
        content={
            "translation": translation,
        }
    )


@router.get("/project/settings")
async def project_settings(
    current_user: UserParam,
    language: str = Query(
        default="en-US", description="Language code", pattern=_language_pattern
    ),
    chat_profile: Optional[str] = Query(
        default=None, description="Current chat profile name"
    ),
):
    """Return project settings. This is called by the UI before the establishing the websocket connection."""

    # Load the markdown file based on the provided language
    markdown = get_markdown_str(config.root, language)

    chat_profiles = []
    profiles: list[dict] = []
    if config.code.set_chat_profiles:
        chat_profiles = await config.code.set_chat_profiles(current_user, language)
        if chat_profiles:
            for p in chat_profiles:
                d = p.to_dict()
                d.pop("config_overrides", None)
                profiles.append(d)

    starters = []
    if config.code.set_starters:
        s = await config.code.set_starters(current_user, language)
        if s:
            starters = [it.to_dict() for it in s]

    data_layer = get_data_layer()
    debug_url = (
        await data_layer.build_debug_url() if data_layer and config.run.debug else None
    )

    cfg = config
    if chat_profile and chat_profiles:
        current_profile = next(
            (p for p in chat_profiles if p.name == chat_profile), None
        )
        if current_profile and getattr(current_profile, "config_overrides", None):
            cfg = config.with_overrides(current_profile.config_overrides)

    return JSONResponse(
        content={
            "ui": cfg.ui.model_dump(),
            "features": cfg.features.model_dump(),
            "userEnv": cfg.project.user_env,
            "maskUserEnv": cfg.project.mask_user_env,
            "dataPersistence": data_layer is not None,
            "threadResumable": bool(config.code.on_chat_resume),
            # Expose whether shared threads feature is enabled (flag + app callback)
            "threadSharing": bool(
                getattr(cfg.features, "allow_thread_sharing", False)
                and getattr(config.code, "on_shared_thread_view", None)
            ),
            "markdown": markdown,
            "chatProfiles": profiles,
            "starters": starters,
            "debugUrl": debug_url,
        }
    )


@router.put("/feedback")
async def update_feedback(
    request: Request,
    update: UpdateFeedbackRequest,
    current_user: UserParam,
):
    """Update the human feedback for a particular message."""
    data_layer = get_data_layer()
    if not data_layer:
        raise HTTPException(status_code=500, detail="Data persistence is not enabled")

    try:
        feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)

        if config.code.on_feedback:
            try:
                from chainlit.context import init_ws_context
                from chainlit.session import WebsocketSession

                session = WebsocketSession.get_by_id(update.sessionId)
                init_ws_context(session)

                await config.code.on_feedback(update.feedback)
            except Exception as callback_error:
                logger.error(
                    f"Error in user-provided on_feedback callback: {callback_error}"
                )
                # Optionally, you could continue without raising an exception to avoid disrupting the endpoint.
    except Exception as e:
        raise HTTPException(detail=str(e), status_code=500) from e

    return JSONResponse(content={"success": True, "feedbackId": feedback_id})


@router.delete("/feedback")
async def delete_feedback(
    request: Request,
    payload: DeleteFeedbackRequest,
    current_user: UserParam,
):
    """Delete a feedback."""

    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    feedback_id = payload.feedbackId

    await data_layer.delete_feedback(feedback_id)
    return JSONResponse(content={"success": True})


@router.post("/project/threads")
async def get_user_threads(
    request: Request,
    payload: GetThreadsRequest,
    current_user: UserParam,
):
    """Get the threads page by page."""

    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if not isinstance(current_user, PersistedUser):
        persisted_user = await data_layer.get_user(identifier=current_user.identifier)
        if not persisted_user:
            raise HTTPException(status_code=404, detail="User not found")
        payload.filter.userId = persisted_user.id
    else:
        payload.filter.userId = current_user.id

    res = await data_layer.list_threads(payload.pagination, payload.filter)
    return JSONResponse(content=res.to_dict())


@router.get("/project/thread/{thread_id}")
async def get_thread(
    request: Request,
    thread_id: str,
    current_user: UserParam,
):
    """Get a specific thread."""
    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    await is_thread_author(current_user.identifier, thread_id)

    res = await data_layer.get_thread(thread_id)
    return JSONResponse(content=res)


@router.get("/project/share/{thread_id}")
async def get_shared_thread(
    request: Request,
    thread_id: str,
    current_user: UserParam,
):
    """Get a shared thread (read-only for everyone).

    This endpoint is separate from the resume endpoint and does not require the caller
    to be the author of the thread. It only returns the thread if its metadata
    contains is_shared=True. Otherwise, it returns 404 to avoid leaking existence.
    """

    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    # No auth required: allow anonymous access to shared threads
    thread = await data_layer.get_thread(thread_id)

    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")
    # Extract and normalize metadata (may be dict, strified JSON, or None)
    metadata = (thread.get("metadata") if isinstance(thread, dict) else {}) or {}
    if isinstance(metadata, str):
        try:
            metadata = json.loads(metadata)
        except Exception:
            metadata = {}
    if not isinstance(metadata, dict):
        metadata = {}

    if getattr(config.code, "on_shared_thread_view", None):
        try:
            user_can_view = await config.code.on_shared_thread_view(
                thread, current_user
            )
        except Exception:
            user_can_view = False

    is_shared = bool(metadata.get("is_shared"))

    # Proceed only raise an error if both conditions are False.
    if (not user_can_view) and (not is_shared):
        raise HTTPException(status_code=404, detail="Thread not found")

    metadata.pop("chat_profile", None)
    metadata.pop("chat_settings", None)
    metadata.pop("env", None)
    thread["metadata"] = metadata
    return JSONResponse(content=thread)


@router.get("/project/thread/{thread_id}/element/{element_id}")
async def get_thread_element(
    request: Request,
    thread_id: str,
    element_id: str,
    current_user: UserParam,
):
    """Get a specific thread element."""
    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    await is_thread_author(current_user.identifier, thread_id)

    res = await data_layer.get_element(thread_id, element_id)
    return JSONResponse(content=res)


@router.put("/project/element")
async def update_thread_element(
    payload: ElementRequest,
    current_user: UserParam,
):
    """Update a specific thread element."""

    from chainlit.context import init_ws_context
    from chainlit.element import Element, ElementDict
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    context = init_ws_context(session)

    element_dict = cast(ElementDict, payload.element)

    if element_dict["type"] != "custom":
        return {"success": False}

    element = Element.from_dict(element_dict)

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to update elements for this session",
            )

    await element.update()
    return {"success": True}


@router.delete("/project/element")
async def delete_thread_element(
    payload: ElementRequest,
    current_user: UserParam,
):
    """Delete a specific thread element."""

    from chainlit.context import init_ws_context
    from chainlit.element import CustomElement, ElementDict
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    context = init_ws_context(session)

    element_dict = cast(ElementDict, payload.element)

    if element_dict["type"] != "custom":
        return {"success": False}

    element = CustomElement(
        id=element_dict["id"],
        object_key=element_dict["objectKey"],
        chainlit_key=element_dict["chainlitKey"],
        url=element_dict["url"],
        for_id=element_dict.get("forId") or "",
        thread_id=element_dict.get("threadId") or "",
        name=element_dict["name"],
        props=element_dict.get("props") or {},
        display=element_dict["display"],
    )

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to remove elements for this session",
            )

    await element.remove()

    return {"success": True}


@router.put("/project/thread")
async def rename_thread(
    request: Request,
    payload: UpdateThreadRequest,
    current_user: UserParam,
):
    """Rename a thread."""

    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    thread_id = payload.threadId

    await is_thread_author(current_user.identifier, thread_id)

    await data_layer.update_thread(thread_id, name=payload.name)

    return JSONResponse(content={"success": True})


@router.put("/project/thread/share")
async def share_thread(
    request: Request,
    payload: ShareThreadRequest,
    current_user: UserParam,
):
    """Share or un-share a thread (author only)."""

    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    thread_id = payload.threadId

    await is_thread_author(current_user.identifier, thread_id)

    # Fetch current thread and metadata, then toggle is_shared
    thread = await data_layer.get_thread(thread_id=thread_id)
    metadata = (thread.get("metadata") if thread else {}) or {}
    if isinstance(metadata, str):
        try:
            metadata = json.loads(metadata)
        except Exception:
            metadata = {}
    if not isinstance(metadata, dict):
        metadata = {}

    metadata = dict(metadata)
    is_shared = bool(payload.isShared)
    metadata["is_shared"] = is_shared
    if is_shared:
        metadata["shared_at"] = utc_now()
    else:
        metadata.pop("shared_at", None)
    try:
        await data_layer.update_thread(thread_id=thread_id, metadata=metadata)
        logger.debug(
            "[share_thread] updated metadata for thread=%s to %s",
            thread_id,
            metadata,
        )
    except Exception as e:
        logger.exception("[share_thread] update_thread failed: %s", e)
        raise

    return JSONResponse(content={"success": True})


@router.delete("/project/thread")
async def delete_thread(
    request: Request,
    payload: DeleteThreadRequest,
    current_user: UserParam,
):
    """Delete a thread."""

    data_layer = get_data_layer()

    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    thread_id = payload.threadId

    await is_thread_author(current_user.identifier, thread_id)

    await data_layer.delete_thread(thread_id)
    return JSONResponse(content={"success": True})


@router.post("/project/action")
async def call_action(
    payload: CallActionRequest,
    current_user: UserParam,
):
    """Run an action."""

    from chainlit.action import Action
    from chainlit.context import init_ws_context
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    context = init_ws_context(session)
    config: ChainlitConfig = session.get_config()

    action = Action(**payload.action)

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to upload files for this session",
            )

    callback = config.code.action_callbacks.get(action.name)
    if callback:
        if not context.session.has_first_interaction:
            context.session.has_first_interaction = True
            asyncio.create_task(context.emitter.init_thread(action.name))

        response = await callback(action)
    else:
        raise HTTPException(
            status_code=404,
            detail=f"No callback found for action {action.name}",
        )

    return JSONResponse(content={"success": True, "response": response})


@router.post("/mcp")
async def connect_mcp(
    payload: ConnectMCPRequest,
    current_user: UserParam,
):
    from mcp import ClientSession
    from mcp.client.sse import sse_client
    from mcp.client.stdio import (
        StdioServerParameters,
        get_default_environment,
        stdio_client,
    )
    from mcp.client.streamable_http import streamablehttp_client

    from chainlit.context import init_ws_context
    from chainlit.mcp import (
        HttpMcpConnection,
        McpConnection,
        SseMcpConnection,
        StdioMcpConnection,
        validate_mcp_command,
    )
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    context = init_ws_context(session)
    config: ChainlitConfig = session.get_config()

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
            )

    mcp_enabled = config.features.mcp.enabled
    if mcp_enabled:
        if payload.name in session.mcp_sessions:
            old_client_session, old_exit_stack = session.mcp_sessions[payload.name]
            if on_mcp_disconnect := config.code.on_mcp_disconnect:
                await on_mcp_disconnect(payload.name, old_client_session)
            try:
                await old_exit_stack.aclose()
            except Exception:
                pass

        try:
            exit_stack = AsyncExitStack()
            mcp_connection: McpConnection

            if payload.clientType == "sse":
                if not config.features.mcp.sse.enabled:
                    raise HTTPException(
                        status_code=400,
                        detail="SSE MCP is not enabled",
                    )

                mcp_connection = SseMcpConnection(
                    url=payload.url,
                    name=payload.name,
                    headers=getattr(payload, "headers", None),
                )

                transport = await exit_stack.enter_async_context(
                    sse_client(
                        url=mcp_connection.url,
                        headers=mcp_connection.headers,
                    )
                )
            elif payload.clientType == "stdio":
                if not config.features.mcp.stdio.enabled:
                    raise HTTPException(
                        status_code=400,
                        detail="Stdio MCP is not enabled",
                    )

                env_from_cmd, command, args = validate_mcp_command(payload.fullCommand)
                mcp_connection = StdioMcpConnection(
                    command=command, args=args, name=payload.name
                )

                env = get_default_environment()
                env.update(env_from_cmd)
                # Create the server parameters
                server_params = StdioServerParameters(
                    command=command, args=args, env=env
                )

                transport = await exit_stack.enter_async_context(
                    stdio_client(server_params)
                )

            elif payload.clientType == "streamable-http":
                if not config.features.mcp.streamable_http.enabled:
                    raise HTTPException(
                        status_code=400,
                        detail="HTTP MCP is not enabled",
                    )
                mcp_connection = HttpMcpConnection(
                    url=payload.url,
                    name=payload.name,
                    headers=getattr(payload, "headers", None),
                )
                transport = await exit_stack.enter_async_context(
                    streamablehttp_client(
                        url=mcp_connection.url,
                        headers=mcp_connection.headers,
                    )
                )

            # The transport can return (read, write) for stdio, sse
            # Or (read, write, get_session_id) for streamable-http
            # We are only interested in the read and write streams here.
            read, write = transport[:2]

            mcp_session: ClientSession = await exit_stack.enter_async_context(
                ClientSession(
                    read_stream=read, write_stream=write, sampling_callback=None
                )
            )

            # Initialize the session
            await mcp_session.initialize()

            # Store the session
            session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack)

            # Call the callback
            if config.code.on_mcp_connect:
                await config.code.on_mcp_connect(mcp_connection, mcp_session)

        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not connect to the MCP: {e!s}",
            )
    else:
        raise HTTPException(
            status_code=400,
            detail="This app does not support MCP.",
        )

    tool_list = await mcp_session.list_tools()

    return JSONResponse(
        content={
            "success": True,
            "mcp": {
                "name": payload.name,
                "tools": [{"name": t.name} for t in tool_list.tools],
                "clientType": payload.clientType,
                "command": payload.fullCommand
                if payload.clientType == "stdio"
                else None,
                "url": getattr(payload, "url", None)
                if payload.clientType in ["sse", "streamable-http"]
                else None,
                # Include optional headers for SSE and streamable-http connections
                "headers": getattr(payload, "headers", None)
                if payload.clientType in ["sse", "streamable-http"]
                else None,
            },
        }
    )


@router.delete("/mcp")
async def disconnect_mcp(
    payload: DisconnectMCPRequest,
    current_user: UserParam,
):
    from chainlit.context import init_ws_context
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    context = init_ws_context(session)

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
            )

    callback = config.code.on_mcp_disconnect
    if payload.name in session.mcp_sessions:
        try:
            client_session, exit_stack = session.mcp_sessions[payload.name]
            if callback:
                await callback(payload.name, client_session)

            try:
                await exit_stack.aclose()
            except Exception:
                pass
            del session.mcp_sessions[payload.name]

        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not disconnect to the MCP: {e!s}",
            )

    return JSONResponse(content={"success": True})


@router.post("/project/file")
async def upload_file(
    current_user: UserParam,
    session_id: str,
    file: UploadFile,
    ask_parent_id: Optional[str] = None,
):
    """Upload a file to the session files directory."""

    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(session_id)

    if not session:
        raise HTTPException(
            status_code=404,
            detail="Session not found",
        )

    if current_user:
        if not session.user or session.user.identifier != current_user.identifier:
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to upload files for this session",
            )

    session.files_dir.mkdir(exist_ok=True)

    try:
        content = await file.read()

        assert file.filename, "No filename for uploaded file"
        assert file.content_type, "No content type for uploaded file"

        spec: AskFileSpec = session.files_spec.get(ask_parent_id, None)
        if not spec and ask_parent_id:
            raise HTTPException(
                status_code=404,
                detail="Parent message not found",
            )

        try:
            validate_file_upload(file, spec=spec)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))

        file_response = await session.persist_file(
            name=file.filename, content=content, mime=file.content_type
        )

        return JSONResponse(content=file_response)
    finally:
        await file.close()


def validate_file_upload(file: UploadFile, spec: Optional[AskFileSpec] = None):
    """Validate the file upload as configured in config.features.spontaneous_file_upload or by AskFileSpec
    for a specific message.

    Args:
        file (UploadFile): The file to validate.
        spec (AskFileSpec): The file spec to validate against if any.
    Raises:
        ValueError: If the file is not allowed.
    """
    if not spec and config.features.spontaneous_file_upload is None:
        """Default for a missing config is to allow the fileupload without any restrictions"""
        return

    if not spec and not config.features.spontaneous_file_upload.enabled:
        raise ValueError("File upload is not enabled")

    validate_file_mime_type(file, spec)
    validate_file_size(file, spec)


def validate_file_mime_type(file: UploadFile, spec: Optional[AskFileSpec]):
    """Validate the file mime type as configured in config.features.spontaneous_file_upload.
    Args:
        file (UploadFile): The file to validate.
    Raises:
        ValueError: If the file type is not allowed.
    """

    if not spec and (
        config.features.spontaneous_file_upload is None
        or config.features.spontaneous_file_upload.accept is None
    ):
        "Accept is not configured, allowing all file types"
        return

    accept = config.features.spontaneous_file_upload.accept if not spec else spec.accept

    assert isinstance(accept, List) or isinstance(accept, dict), (
        "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
    )

    if isinstance(accept, List):
        for pattern in accept:
            if fnmatch.fnmatch(str(file.content_type), pattern):
                return
    elif isinstance(accept, dict):
        for pattern, extensions in accept.items():
            if fnmatch.fnmatch(str(file.content_type), pattern):
                if len(extensions) == 0:
                    return
                for extension in extensions:
                    if file.filename is not None and file.filename.lower().endswith(
                        extension.lower()
                    ):
                        return
    raise ValueError("File type not allowed")


def validate_file_size(file: UploadFile, spec: Optional[AskFileSpec]):
    """Validate the file size as configured in config.features.spontaneous_file_upload.
    Args:
        file (UploadFile): The file to validate.
    Raises:
        ValueError: If the file size is too large.
    """
    if not spec and (
        config.features.spontaneous_file_upload is None
        or config.features.spontaneous_file_upload.max_size_mb is None
    ):
        return

    max_size_mb = (
        config.features.spontaneous_file_upload.max_size_mb
        if not spec
        else spec.max_size_mb
    )
    if file.size is not None and file.size > max_size_mb * 1024 * 1024:
        raise ValueError("File size too large")


@router.get("/project/file/{file_id}")
async def get_file(
    file_id: str,
    session_id: str,
    current_user: UserParam,
):
    """Get a file from the session files directory."""
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(session_id) if session_id else None

    if not session:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized",
        )

    if current_user:
        if not session.user or session.user.identifier != current_user.identifier:
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to download files from this session",
            )

    if file_id in session.files:
        file = session.files[file_id]
        return FileResponse(file["path"], media_type=file["type"])
    else:
        raise HTTPException(status_code=404, detail="File not found")


@router.get("/favicon")
async def get_favicon():
    """Get the favicon for the UI."""
    custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
    files = glob.glob(custom_favicon_path)

    if files:
        favicon_path = files[0]
    else:
        favicon_path = os.path.join(build_dir, "favicon.svg")

    media_type, _ = mimetypes.guess_type(favicon_path)

    return FileResponse(favicon_path, media_type=media_type)


@router.get("/logo")
async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
    """Get the default logo for the UI."""
    theme_value = theme.value if theme else Theme.light.value
    logo_path = None

    for path in [
        os.path.join(APP_ROOT, "public", f"logo_{theme_value}.*"),
        os.path.join(build_dir, "assets", f"logo_{theme_value}*.*"),
    ]:
        files = glob.glob(path)

        if files:
            logo_path = files[0]
            break

    if not logo_path:
        logo_path = os.path.join(
            os.path.dirname(__file__),
            "frontend",
            "dist",
            f"logo_{theme_value}.svg",
        )
        logger.info("Missing custom logo. Falling back to default logo.")

    media_type, _ = mimetypes.guess_type(logo_path)

    return FileResponse(logo_path, media_type=media_type)


@router.get("/avatars/{avatar_id:str}")
async def get_avatar(avatar_id: str):
    """Get the avatar for the user based on the avatar_id."""
    if not re.match(r"^[a-zA-Z0-9_ .-]+$", avatar_id):
        raise HTTPException(status_code=400, detail="Invalid avatar_id")

    if avatar_id == "default":
        avatar_id = config.ui.name

    avatar_id = avatar_id.strip().lower().replace(" ", "_").replace(".", "_")

    base_path = Path(APP_ROOT) / "public" / "avatars"
    avatar_pattern = f"{avatar_id}.*"

    matching_files = base_path.glob(avatar_pattern)

    if avatar_path := next(matching_files, None):
        if not is_path_inside(avatar_path, base_path):
            raise HTTPException(status_code=400, detail="Invalid filename")
        media_type, _ = mimetypes.guess_type(str(avatar_path))

        return FileResponse(avatar_path, media_type=media_type)

    return await get_favicon()


@router.head("/")
def status_check():
    """Check if the site is operational."""
    return {"message": "Site is operational"}


@router.get("/{full_path:path}")
async def serve(request: Request):
    """Serve the UI files."""
    root_path = os.getenv("CHAINLIT_PARENT_ROOT_PATH", "") + os.getenv(
        "CHAINLIT_ROOT_PATH", ""
    )
    html_template = get_html_template(root_path)
    response = HTMLResponse(content=html_template, status_code=200)

    return response


app.include_router(router)

import chainlit.socket  # noqa
